package org.stagemonitor.web.tracing;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpRequest;
import org.springframework.http.client.ClientHttpRequestExecution;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.converter.StringHttpMessageConverter;
import org.springframework.mock.http.client.MockClientHttpRequest;
import org.springframework.web.client.RestTemplate;
import org.stagemonitor.tracing.B3HeaderFormat;
import org.stagemonitor.tracing.TracingPlugin;
import org.stagemonitor.tracing.tracing.B3Propagator;
import org.stagemonitor.web.tracing.SpringRestTemplateContextPropagatingTransformer.SpringRestTemplateContextPropagatingInterceptor;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import io.opentracing.mock.MockTracer;
import static org.hamcrest.CoreMatchers.hasItem;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.CoreMatchers.isA;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
/**
 * Tests for {@link SpringRestTemplateContextPropagatingTransformer} and its nested
 * {@link SpringRestTemplateContextPropagatingInterceptor}.
 */
public class SpringRestTemplateContextPropagatingTransformerTest {

	/** Mocked plugin whose {@code getTracer()} is stubbed to return a B3-propagating {@link MockTracer}. */
	private TracingPlugin tracingPlugin;

	/**
	 * Creates a fresh mocked {@link TracingPlugin} for every test and stubs its tracer with a
	 * {@link MockTracer} that uses a {@link B3Propagator}.
	 */
	@Before
	public void setUp() throws Exception {
		final MockTracer b3Tracer = new MockTracer(new B3Propagator());
		tracingPlugin = mock(TracingPlugin.class);
		when(tracingPlugin.getTracer()).thenReturn(b3Tracer);
	}

	/**
	 * Every newly constructed {@link RestTemplate} must carry exactly one interceptor of type
	 * {@link SpringRestTemplateContextPropagatingInterceptor}. This is checked for the no-arg,
	 * request-factory and message-converter-list constructors.
	 * <p>
	 * NOTE(review): nothing in this test registers the interceptor itself. It presumably comes from
	 * the transformer instrumenting {@code RestTemplate}, so the instrumentation must be active in
	 * the test JVM. Confirm how it is installed.
	 */
	@Test
	public void testRestTemplateHasInterceptor() throws Exception {
		final List<RestTemplate> templatesUnderTest = Arrays.asList(
				new RestTemplate(),
				new RestTemplate(new SimpleClientHttpRequestFactory()),
				new RestTemplate(Collections.singletonList(new StringHttpMessageConverter())));
		for (final RestTemplate template : templatesUnderTest) {
			assertHasOnlyPropagatingInterceptor(template);
		}
	}

	/**
	 * Running the interceptor on a plain GET request must add both the B3 span-id and the
	 * B3 trace-id headers to that request.
	 */
	@Test
	public void testB3HeaderContextPropagation() throws Exception {
		final URI target = new URI("http://example.com");
		final HttpRequest request = new MockClientHttpRequest(HttpMethod.GET, target);
		final SpringRestTemplateContextPropagatingInterceptor interceptor =
				new SpringRestTemplateContextPropagatingInterceptor(tracingPlugin);
		final ClientHttpRequestExecution execution = mock(ClientHttpRequestExecution.class);

		// The body is null and the execution is a Mockito mock, so no real HTTP call is made.
		interceptor.intercept(request, null, execution);

		assertTrue(request.getHeaders().containsKey(B3HeaderFormat.SPAN_ID_NAME));
		assertTrue(request.getHeaders().containsKey(B3HeaderFormat.TRACE_ID_NAME));
	}

	/**
	 * Asserts that {@code restTemplate} has exactly one interceptor and that it is a
	 * {@link SpringRestTemplateContextPropagatingInterceptor}.
	 */
	private static void assertHasOnlyPropagatingInterceptor(RestTemplate restTemplate) {
		assertThat(restTemplate.getInterceptors().size(), is(1));
		assertThat(restTemplate.getInterceptors(), hasItem(isA(SpringRestTemplateContextPropagatingInterceptor.class)));
	}
}